use crate::cons;

use super::IIorF;
use super::Matrix;
use std::ops;

impl Matrix<f64> {
    /// Determinant, computed by Laplace (cofactor) expansion along the first row.
    ///
    /// * Non-square matrices yield `0.0`.
    /// * The empty `0 x 0` matrix yields `1.0` (the empty product). This is what makes
    ///   the cofactor of a `1 x 1` matrix come out as `1`.
    ///
    /// Cost grows factorially with the dimension, and every expansion step allocates a
    /// minor, so this is only suitable for small matrices.
    pub fn det(&self) -> f64 {
        let n = self.row();
        if n != self.col() {
            return 0.0;
        }

        match n {
            0 => 1.0,
            1 => self[(0, 0)],
            2 => self[(0, 0)] * self[(1, 1)] - self[(0, 1)] * self[(1, 0)],
            _ => (0..n).fold(0.0_f64, |acc, j| {
                let sign = if j % 2 == 0 { 1.0 } else { -1.0 };
                acc + sign * self[(0, j as isize)] * self.minor_submatrix(0, j).det()
            }),
        }
    }

    /// Signed cofactor `C[i][j] = (-1)^(i+j) * det(M_ij)`.
    ///
    /// `M_ij` is this matrix with row `i` and column `j` removed.
    ///
    /// The matrix must have at least one row and one column. `self.row() - 1` is
    /// computed unconditionally.
    pub fn cofactor(&self, i: usize, j: usize) -> f64 {
        let sign = if (i + j) % 2 == 0 { 1.0 } else { -1.0 };
        self.minor_submatrix(i, j).det() * sign
    }

    /// Adjugate (classical adjoint): the *transpose* of the cofactor matrix.
    ///
    /// With this definition `A * adj(A) = det(A) * I`, so `A^-1 = adj(A) / det(A)`.
    ///
    /// A non-square matrix has no adjugate. In that case the freshly constructed
    /// `Matrix::new(row, col)` is returned without being filled in.
    pub fn adjoint(&self) -> Matrix<f64> {
        let mut adj = Matrix::new(self.row(), self.col());
        if self.row() != self.col() {
            return adj;
        }
        for i in 0..self.row() {
            for j in 0..self.col() {
                // adj(A)[i][j] = C[j][i]: the indices are deliberately swapped.
                adj[(i as isize, j as isize)] = self.cofactor(j, i);
            }
        }
        adj
    }

    /// Returns a copy of this matrix with row `skip_row` and column `skip_col` removed.
    ///
    /// Source rows and columns other than the skipped ones are packed, in order, into
    /// a `(row - 1) x (col - 1)` matrix. The caller must ensure the matrix is non-empty
    /// and that `skip_row` / `skip_col` are in range.
    fn minor_submatrix(&self, skip_row: usize, skip_col: usize) -> Matrix<f64> {
        let mut sub = Matrix::new(self.row() - 1, self.col() - 1);
        for (dst_r, src_r) in (0..self.row()).filter(|&r| r != skip_row).enumerate() {
            for (dst_c, src_c) in (0..self.col()).filter(|&c| c != skip_col).enumerate() {
                sub[(dst_r as isize, dst_c as isize)] = self[(src_r as isize, src_c as isize)];
            }
        }
        sub
    }
}

impl ops::Add for &Matrix<f64> {
    type Output = Result<Matrix<f64>, bool>;

    /// Element-wise sum `self + rhs`.
    ///
    /// # Errors
    /// Returns `Err(false)` when the operands do not have identical dimensions.
    fn add(self, rhs: Self) -> Self::Output {
        let shapes_match = self.row() == rhs.row() && self.col() == rhs.col();
        if !shapes_match {
            return Err(false);
        }

        let mut total = Matrix::new(self.row(), self.col());
        for r in 0..self.row() as isize {
            for c in 0..self.col() as isize {
                total[(r, c)] = self[(r, c)] + rhs[(r, c)];
            }
        }
        Ok(total)
    }
}

impl ops::Sub for &Matrix<f64> {
    type Output = Result<Matrix<f64>, bool>;

    /// Element-wise difference `self - rhs`.
    ///
    /// # Errors
    /// Returns `Err(false)` if the two matrices differ in row or column count.
    fn sub(self, rhs: Self) -> Self::Output {
        let (rows, cols) = (self.row(), self.col());
        if (rows, cols) != (rhs.row(), rhs.col()) {
            return Err(false);
        }

        let mut diff = Matrix::new(rows, cols);
        // Visit every (row, col) cell in row-major order.
        let cells = (0..rows).flat_map(|r| (0..cols).map(move |c| (r as isize, c as isize)));
        for cell in cells {
            diff[cell] = self[cell] - rhs[cell];
        }
        Ok(diff)
    }
}

impl ops::Mul for &Matrix<f64> {
    type Output = Result<Matrix<f64>, bool>;

    /// Matrix product `self * rhs`: entry `(r, c)` is the dot product of row `r` of
    /// `self` with column `c` of `rhs`.
    ///
    /// # Errors
    /// Returns `Err(false)` when `self.col() != rhs.row()`.
    fn mul(self, rhs: Self) -> Self::Output {
        let inner = self.col();
        if inner != rhs.row() {
            return Err(false);
        }

        let (out_rows, out_cols) = (self.row(), rhs.col());
        let mut product = Matrix::new(out_rows, out_cols);
        for r in 0..out_rows as isize {
            for c in 0..out_cols as isize {
                product[(r, c)] =
                    (0..inner as isize).fold(0.0_f64, |acc, k| acc + self[(r, k)] * rhs[(k, c)]);
            }
        }
        Ok(product)
    }
}

impl<T> ops::Mul<T> for Matrix<f64>
where
    T: IIorF,
{
    type Output = Self;

    /// Scalar multiplication: every element is multiplied by `rhs`, converted to `f64`
    /// via `IIorF::to_f64`. Returns a new matrix of the same shape.
    fn mul(self, rhs: T) -> Self::Output {
        let factor = rhs.to_f64();
        let (rows, cols) = (self.row(), self.col());
        let mut scaled = Matrix::new(rows, cols);
        for idx in (0..rows).flat_map(|r| (0..cols).map(move |c| (r as isize, c as isize))) {
            scaled[idx] = self[idx] * factor;
        }
        scaled
    }
}

impl ops::BitXor<i32> for &Matrix<f64> {
    type Output = Result<Matrix<f64>, bool>;

    /// Integer matrix power, written `&a ^ k`.
    ///
    /// * `k >= 0`: `A^k`, computed by `k` successive multiplications starting from the
    ///   identity of the same size, so `A^0` is the `n x n` identity.
    /// * `k == -1`: the inverse `A^-1 = adj(A) / det(A)`, where `adj(A)` is the
    ///   transpose of the cofactor matrix.
    ///
    /// # Errors
    /// Returns `Err(false)` when `k < -1`, when the matrix is not square, or when
    /// `k == -1` and the determinant is exactly `0.0`.
    fn bitxor(self, rhs: i32) -> Self::Output {
        let n = self.row();
        if rhs < -1 || n != self.col() {
            return Err(false);
        }

        if rhs == -1 {
            // The determinant is expensive; evaluate it once and reuse it.
            let det = self.det();
            if det == 0.0 {
                return Err(false);
            }
            let mut inv = Matrix::new(n, n);
            if n == 1 {
                // inverse([a]) = [1/a]. Handled directly so it does not rely on the
                // cofactor of a 1x1 matrix (determinant of an empty minor).
                inv[(0, 0)] = 1.0 / det;
                return Ok(inv);
            }
            for i in 0..n {
                for j in 0..n {
                    // (A^-1)[i][j] = C[j][i] / det: the adjugate is the transposed
                    // cofactor matrix, hence the swapped indices.
                    inv[(i as isize, j as isize)] = self.cofactor(j, i) / det;
                }
            }
            return Ok(inv);
        }

        // The starting identity must match the matrix dimension, not the exponent.
        let mut res = cons::cons_mat::eye(n);
        for _ in 0..rhs {
            res = (&res * self).expect("product of two n x n matrices always conforms");
        }
        Ok(res)
    }
}
